package redis

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"github.com/RichardKnop/machinery/v2/brokers/errs"
	"github.com/RichardKnop/machinery/v2/brokers/iface"
	"github.com/RichardKnop/machinery/v2/common"
	"github.com/RichardKnop/machinery/v2/config"
	"github.com/RichardKnop/machinery/v2/log"
	"github.com/RichardKnop/machinery/v2/tasks"
	"github.com/go-redis/redis/v8"
	"runtime"
	"strings"
	"sync"
	"time"
)

const (
	// defaultRedisDelayedTasksKey is the sorted set that holds ETA-delayed tasks
	// when config.Redis.DelayedTasksKey is not set.
	defaultRedisDelayedTasksKey = "delayed_tasks"

	// defaultRedisBroadcastTaskKey is the Redis stream used for broadcast tasks
	// when no stream name is supplied.
	defaultRedisBroadcastTaskKey = "broadcast_tasks"

	// defaultStreamBroadcastMsgKey is the stream-entry field that carries the
	// JSON-encoded task signature.
	defaultStreamBroadcastMsgKey = "task_signature"

	// defaultBroadcastHeaderKey is the signature header whose presence marks a
	// task as a broadcast task.
	defaultBroadcastHeaderKey = "broadcastTask"
)

// BrokerBroadcast is a Redis broker that extends BrokerGR with support for
// broadcast tasks: tasks appended to a Redis stream and read by every broker
// instance through its own stream cursor.
type BrokerBroadcast struct {
	*BrokerGR
	broadcastWG            sync.WaitGroup // tracks the broadcast-reading goroutine started by StartConsuming so StopConsuming can wait for it
	redisBroadcastTasksKey string         // default broadcast stream name: read by StartConsuming, used by Publish when the header names no stream
	broadcastHeaderKey     string         // signature header key; when present, Publish sends the task to a broadcast stream
	lastBroadcastMsgId     string         // ID of the last broadcast stream entry read; "" before the first read, "0" if the stream was empty then
}

// New creates a BrokerBroadcast backed by a go-redis universal client.
//
// addrs lists the Redis addresses. The first entry may carry a password in the
// form "password@host:port". The password is split off at the LAST '@', so
// passwords containing '@' are supported. The caller's slice is never
// modified. An empty addrs is passed through to go-redis unchanged.
// db selects the Redis database.
//
// cnf.Redis is optional. When present, it supplies MasterName and
// DelayedTasksKey. When it is nil, or the key is empty,
// defaultRedisDelayedTasksKey is used. The broadcast stream and header keys
// start at their package defaults.
func New(cnf *config.Config, addrs []string, db int) *BrokerBroadcast {
	gr := &BrokerGR{Broker: common.NewBroker(cnf)}

	// Work on a private copy so stripping the password does not mutate the
	// caller's slice.
	addrs = append([]string(nil), addrs...)

	var password string
	if len(addrs) > 0 {
		// A host:port never contains '@', so everything before the last '@' is
		// the password (which may itself contain '@').
		if i := strings.LastIndex(addrs[0], "@"); i >= 0 {
			password = addrs[0][:i]
			addrs[0] = addrs[0][i+1:]
		}
	}

	ropt := &redis.UniversalOptions{
		Addrs:    addrs,
		DB:       db,
		Password: password,
	}

	gr.redisDelayedTasksKey = defaultRedisDelayedTasksKey
	// cnf.Redis may be nil; guard every access to it, not just MasterName.
	if cnf.Redis != nil {
		ropt.MasterName = cnf.Redis.MasterName
		if cnf.Redis.DelayedTasksKey != "" {
			gr.redisDelayedTasksKey = cnf.Redis.DelayedTasksKey
		}
	}

	gr.rclient = redis.NewUniversalClient(ropt)

	return &BrokerBroadcast{
		BrokerGR:               gr,
		redisBroadcastTasksKey: defaultRedisBroadcastTaskKey,
		broadcastHeaderKey:     defaultBroadcastHeaderKey,
	}
}

// NewWithOptions builds a broker with New and then applies every Option to it,
// in the order they were given, before returning it.
func NewWithOptions(cnf *config.Config, addrs []string, db int, opts ...Option) *BrokerBroadcast {
	broker := New(cnf, addrs, db)
	for i := range opts {
		apply := opts[i]
		apply(broker)
	}
	return broker
}

// StartConsuming enters a loop and waits for incoming messages.
//
// It starts three goroutines:
//   - a queue receiver that pops normal tasks and feeds them to deliveries;
//   - a delayed-task mover that republishes tasks whose ETA has come due;
//   - a broadcast receiver that reads b.redisBroadcastTasksKey via XREAD
//     (no consumer group, so each broker instance sees every message) and
//     feeds them to the same deliveries channel.
//
// It then blocks in b.consume until deliveries is closed (on stop) or an
// error occurs. It returns the broker's retry flag with that error.
func (b *BrokerBroadcast) StartConsuming(consumerTag string, concurrency int, taskProcessor iface.TaskProcessor) (bool, error) {
	b.consumingWG.Add(1)
	defer b.consumingWG.Done()

	if concurrency < 1 {
		concurrency = runtime.NumCPU() * 2
	}

	b.Broker.StartConsuming(consumerTag, concurrency, taskProcessor)

	// Ping the server to make sure connection is live
	_, err := b.rclient.Ping(context.Background()).Result()
	if err != nil {
		b.GetRetryFunc()(b.GetRetryStopChan())

		// Return err if retry is still true.
		// If retry is false, broker.StopConsuming() has been called and
		// therefore Redis might have been stopped. Return nil exit
		// StartConsuming()
		if b.GetRetry() {
			return b.GetRetry(), err
		}
		return b.GetRetry(), errs.ErrConsumerStopped
	}

	// Channel to which we will push tasks ready for processing by worker.
	// It has TWO senders: the queue receiver and the broadcast receiver.
	deliveries := make(chan []byte, concurrency)
	pool := make(chan struct{}, concurrency)

	// initialize worker pool with maxWorkers workers
	for i := 0; i < concurrency; i++ {
		pool <- struct{}{}
	}

	// Register the broadcast receiver BEFORE the queue receiver starts. The
	// queue receiver waits on broadcastWG before closing deliveries, so that
	// wait must never see a zero counter while a broadcast send is still
	// possible.
	b.broadcastWG.Add(1)

	// A receiving goroutine keeps popping messages from the queue by BLPOP
	// If the message is valid and can be unmarshaled into a proper structure
	// we send it to the deliveries channel
	go func() {

		log.INFO.Print("[*] Waiting for messages. To exit press CTRL+C")

		for {
			select {
			// A way to stop this goroutine from b.StopConsuming
			case <-b.GetStopChan():
				// The broadcast receiver also sends on deliveries; closing
				// the channel before it exits could make it panic with
				// "send on closed channel". It observes the same stop signal,
				// so this wait is bounded.
				b.broadcastWG.Wait()
				close(deliveries)
				return
			case <-pool:
				task, _ := b.nextTask(getQueueGR(b.GetConfig(), taskProcessor))
				//TODO: should this error be ignored?
				if len(task) > 0 {
					deliveries <- task
				}

				pool <- struct{}{}
			}
		}
	}()

	// A goroutine to watch for delayed tasks and push them to deliveries
	// channel for consumption by the worker
	b.delayedWG.Add(1)
	go func() {
		defer b.delayedWG.Done()

		for {
			select {
			// A way to stop this goroutine from b.StopConsuming
			case <-b.GetStopChan():
				return
			default:
				task, err := b.nextDelayedTask(b.redisDelayedTasksKey)
				if err != nil {
					continue
				}

				signature := new(tasks.Signature)
				decoder := json.NewDecoder(bytes.NewReader(task))
				decoder.UseNumber()
				if err := decoder.Decode(signature); err != nil {
					log.ERROR.Print(errs.NewErrCouldNotUnmarshalTaskSignature(task, err))
					// Do not republish an empty or partially decoded signature.
					continue
				}

				if err := b.Publish(context.Background(), signature); err != nil {
					log.ERROR.Print(err)
				}
			}
		}
	}()

	// Broadcast receiver: delivers broadcast stream entries straight to the
	// local consumer.
	go func() {
		defer b.broadcastWG.Done()

		for {
			select {
			case <-b.GetStopChan():
				return
			default:
			}

			task, err := b.nextBroadCastTask(b.redisBroadcastTasksKey)
			if len(task) == 0 {
				// redis.Nil means "nothing arrived within the poll period".
				// Anything else (connection failure, malformed entry) may come
				// back without blocking, so back off rather than spin.
				if err != nil && err != redis.Nil {
					log.ERROR.Printf("read broadcast task: %s", err)
					select {
					case <-b.GetStopChan():
						return
					case <-time.After(time.Second):
					}
				}
				continue
			}

			// Never block unconditionally: after stop, the queue receiver
			// closes deliveries and the consumer may no longer be reading.
			select {
			case deliveries <- task:
			case <-b.GetStopChan():
				return
			}
		}
	}()

	if err := b.consume(deliveries, concurrency, taskProcessor); err != nil {
		return b.GetRetry(), err
	}

	// Waiting for any tasks being processed to finish
	b.processingWG.Wait()

	return b.GetRetry(), nil
}

// StopConsuming signals all consuming goroutines to stop. It then blocks until
// the delayed-task loop, the broadcast loop and StartConsuming itself have
// finished, and finally closes the Redis client. The close error is discarded.
func (b *BrokerBroadcast) StopConsuming() {
	b.Broker.StopConsuming()

	// Wait, in order, for: the delayed-task mover, the broadcast receiver,
	// and the whole consumption in StartConsuming.
	waitGroups := []*sync.WaitGroup{&b.delayedWG, &b.broadcastWG, &b.consumingWG}
	for _, wg := range waitGroups {
		wg.Wait()
	}

	_ = b.rclient.Close()
}

// Publish sends a task signature to Redis. Every Redis call uses the caller's
// ctx, so cancellation and deadlines are honoured.
//
// Routing, in priority order:
//  1. If the signature carries the b.broadcastHeaderKey header, the JSON is
//     XADDed to a broadcast stream. The stream is named by the header's string
//     value; a missing, empty or non-string value falls back to
//     b.redisBroadcastTasksKey. The stream is trimmed to roughly 1000 entries.
//     ETA is NOT applied to broadcast tasks.
//  2. Otherwise, if ETA is set and in the future, the task is ZADDed to the
//     delayed-tasks sorted set, scored by ETA in Unix nanoseconds.
//  3. Otherwise the task is RPUSHed onto signature.RoutingKey.
func (b *BrokerBroadcast) Publish(ctx context.Context, signature *tasks.Signature) error {
	// Adjust routing key (this decides which queue the message will be published to)
	b.Broker.AdjustRoutingKey(signature)

	msg, err := json.Marshal(signature)
	if err != nil {
		return fmt.Errorf("JSON marshal error: %w", err)
	}

	// The custom header doubles as the broadcast stream name.
	if v, ok := signature.Headers[b.broadcastHeaderKey]; ok {
		streamKey, isString := v.(string)
		// No stream specified: use the broker's configured default.
		if !isString || streamKey == "" {
			streamKey = b.redisBroadcastTasksKey
		}
		// Approximate cap (XADD MAXLEN ~) on each broadcast stream's length.
		const broadcastStreamMaxLen = 1000
		return b.rclient.XAdd(ctx, &redis.XAddArgs{
			Stream:       streamKey,
			MaxLenApprox: broadcastStreamMaxLen,
			ID:           "*",
			Values:       map[string]interface{}{defaultStreamBroadcastMsgKey: string(msg)},
		}).Err()
	}

	// Check the ETA signature field, if it is set and it is in the future,
	// delay the task
	if signature.ETA != nil && signature.ETA.After(time.Now().UTC()) {
		score := signature.ETA.UnixNano()
		return b.rclient.ZAdd(ctx, b.redisDelayedTasksKey, &redis.Z{Score: float64(score), Member: msg}).Err()
	}

	return b.rclient.RPush(ctx, signature.RoutingKey, msg).Err()
}

// nextBroadCastTask returns the payload of the next entry in the broadcast
// stream queue, reading one entry at a time from b.lastBroadcastMsgId onward.
//
// On the very first call (empty cursor) it positions the cursor at the newest
// existing entry, so only messages added afterwards are delivered. If the
// stream has no entries, it sets the cursor to "0" and returns redis.Nil.
//
// XREAD blocks for up to Redis.NormalTasksPollPeriod milliseconds (default
// 1000). Errors from Redis are returned as-is. An empty result yields
// redis.Nil. The cursor advances past every entry read, even one whose
// payload field is not a string; that case returns an error.
func (b *BrokerBroadcast) nextBroadCastTask(queue string) (result []byte, err error) {
	ctx := context.Background()

	blockFor := 1000 * time.Millisecond
	if redisCnf := b.GetConfig().Redis; redisCnf != nil && redisCnf.NormalTasksPollPeriod > 0 {
		blockFor = time.Duration(redisCnf.NormalTasksPollPeriod) * time.Millisecond
	}

	if b.lastBroadcastMsgId == "" {
		newest, revErr := b.rclient.XRevRangeN(ctx, queue, "+", "-", 1).Result()
		switch {
		case revErr != nil:
			return []byte{}, revErr
		case len(newest) == 0:
			// Nothing in the stream yet: read from the very beginning next time.
			b.lastBroadcastMsgId = "0"
			return []byte{}, redis.Nil
		}
		b.lastBroadcastMsgId = newest[0].ID
	}

	// Consume broadcast messages one at a time.
	res, readErr := b.rclient.XRead(ctx, &redis.XReadArgs{
		Streams: []string{queue, b.lastBroadcastMsgId},
		Count:   1,
		Block:   blockFor,
	}).Result()
	if readErr != nil {
		return []byte{}, readErr
	}
	if len(res) == 0 || len(res[0].Messages) == 0 {
		return []byte{}, redis.Nil
	}

	entry := res[0].Messages[0]
	// Advance before inspecting the payload so an unusable entry is skipped
	// rather than re-read on every call.
	b.lastBroadcastMsgId = entry.ID

	payload, isString := entry.Values[defaultStreamBroadcastMsgKey].(string)
	if !isString {
		return []byte{}, fmt.Errorf("not support msg type")
	}
	return []byte(payload), nil
}
